package depparsing.constraints;

import static util.Array.deepclone;
import constraints.CorpusConstraints;

import gnu.trove.TIntArrayList;

import optimization.stopCriteria.ProjectedGradientL2Norm;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

import optimization.linesearch.NonNewtonInterpolationPickFirstStep;
import optimization.gradientBasedMethods.ProjectedGradientDescent;
import optimization.gradientBasedMethods.stats.ProjectedOptimizerStats;
import optimization.linesearch.GenericPickFirstStep;
import optimization.linesearch.LineSearchMethod;
import optimization.linesearch.WolfRuleLineSearch;
import optimization.stopCriteria.CompositeStopingCriteria;
import optimization.stopCriteria.NormalizedProjectedGradientL2Norm;
import optimization.stopCriteria.NormalizedValueDifference;
import optimization.stopCriteria.StopingCriteria;
import optimization.linesearch.ArmijoLineSearchMinimizationAlongProjectionArc;


import learning.CorpusPR;
import learning.stats.TrainStats;
import model.AbstractCountTable;
import model.AbstractSentenceDist;
import util.MemoryTracker;


import data.WordInstance;
import depparsing.constraints.PCType;
import depparsing.data.DepCorpus;
import depparsing.model.DepModel;
import depparsing.model.DepSentenceDist;

/**
 * 
 * @author kuzman
 * 
 * This represents the dual objective to the L1Lmax penalty.  The primal 
 * objective is:
 *     min   KL(q||p) + \sum_cp Xi_cp
 *     s.t.  Xi_cp <= E_q[f_cpi] for all c,p,i
 * where c is the child item, p is the parent item and i is the index. 
 * For example, to enforce the constraint "each child word is generated by few
 * parent POS tags" you would need c to range over all words, p to range over
 * all POS tags, and there would be a separate i for each possible edge that
 * has child word c and parent POS p. The definition of whether a "parent item" 
 * corresponds to a type (e.g. "Noun") or a particular item (e.g. "the second Noun")
 * is implemented in the {@link GroupedL1LMax} and {@link UngroupedL1LMax}. 
 * 
 * The sentence distributions have edges arranged by sentences in the corpus,
 * but in order to do the simplex projection we need to be able to arrange 
 * them in order of child-type,parent-type,index.  This class stores information 
 * about how to do the reshaping of the parameters when we need to apply the 
 * simplex projection.
 * 
 */
public abstract class L1LMax implements CorpusConstraints {
	

	final int numChildIds;  // number of types of children (e.g. number of words) 
	final int numParentIds; // number of types of parent (e.g. number of tags)
	final ConstraintEnumerator cstraints;
	final DepCorpus corpus;
	final DepModel model;
	
	/** 
	 * In order to avoid re-allocating lambda on every projection, we store it here.
	 * Similarly, originalChildren/originalRoots cache the unprojected posterior
	 * parameters of p, cloned once per call to project().
	 */
	double[] lambda;
	double[][][][] originalChildren;
	double[][] originalRoots;
	
	/** We're going to store the mapping from the sentence, child token index, parent token index to and from 
	 * child type, parent type, edge index in the scp2cpi and cpi2scp arrays.  That way reshaping should not 
	 * require any counting. */
	class SentenceChildParent {
		public int s, c;       // sentence index and child token index
		public int[] parents;  // parent token indices paired with this (sentence, child)

		public SentenceChildParent(int s2, int c2, int[] p2) {
			s = s2;
			c = c2;
			parents = p2;
			if (parents == null) throw new AssertionError("parents is null");
		}
	}

	final SentenceChildParent[][] edge2scp;  // indexed by type, index
	private final double constraintStrength;
	private TIntArrayList edgesToNotProject;
	// Debugging code -- to make sure we're doing everything correctly in terms of counting. 
	SentenceChildParent[] param2scp;
	
	// Line-search and stopping constants used by project(); logged at construction.
	final double c1= 0.0001, c2=0.9, stoppingPrecision = 1e-5, maxStep = 10;
	final int maxZoomEvals = 10, maxExtrapolationIters = 200;
	int maxProjectionIterations = 200;
	int minOccurrencesForProjection = 0;

    double initialStep = 1000;

	/**
	 * Builds the reshaping tables (edge2scp / param2scp) that map between the
	 * per-sentence edge posteriors and the (child type, parent type, index)
	 * layout needed for the simplex projection.
	 *
	 * @param corpus the dependency corpus
	 * @param model the dependency model whose posteriors get projected
	 * @param toProject the sentences whose posteriors are subject to the constraints
	 * @param cType whether a "child item" is a word, tag, etc.
	 * @param pType whether a "parent item" is a word, tag, etc.
	 * @param useRoot whether root attachments participate in the constraints
	 * @param useDirection whether left/right attachments are distinguished
	 * @param constraintStrength the per-edge-type constraint strength (see getConstraintStrength)
	 * @param minOccurrencesForProjection edge types occurring fewer times than this are not projected
	 * @param fileOfAllowedTypes optional file listing edges to exclude from projection; may be null
	 * @throws IOException if fileOfAllowedTypes is given but cannot be read
	 */
	public L1LMax(DepCorpus corpus, DepModel model, ArrayList<WordInstance> toProject, PCType cType, PCType pType, 
			boolean useRoot, boolean useDirection, double constraintStrength, int minOccurrencesForProjection, String fileOfAllowedTypes) throws IOException{

	    System.out.println("L1LMax optimization parameters");
	    System.out.println("c1: " + c1 + " c2: " + c2 + " sc: " + stoppingPrecision + " ms: " + maxStep + 
				" mze: " + maxZoomEvals + " mei: " + maxExtrapolationIters);
	    System.out.println("mProjIter: " + maxProjectionIterations + " minOccur: " + minOccurrencesForProjection);

		this.corpus = corpus;
		this.model = model;
		this.cstraints = new ConstraintEnumerator(corpus, cType, pType, useRoot, useDirection);
		this.constraintStrength = constraintStrength;
		this.minOccurrencesForProjection = minOccurrencesForProjection;
		numChildIds = cstraints.numIdsChild();
		numParentIds = cstraints.numIdsParent();
		ArrayList<Integer> indicesforcp = countIndicesForChildParentType(toProject);
		edge2scp = new SentenceChildParent[indicesforcp.size()][];
		// count how many edge types will not be projected for reporting
		int notToProject = 0;
		// Allocate the ragged per-edge-type arrays; reset the counts to zero so
		// makeEdge2SentenceChildParent can reuse indicesforcp as fill cursors.
		for (int i = 0; i < edge2scp.length; i++) {
			edge2scp[i] = new SentenceChildParent[indicesforcp.get(i)];
			if (minOccurrencesForProjection > edge2scp[i].length){
				notToProject +=1;
			}
			indicesforcp.set(i,0);
		}
		int totalEdgeTypes = indicesforcp.size();
		System.out.println("Will project "+(totalEdgeTypes-notToProject)+" / "+totalEdgeTypes+" the rest fall below min occurrences to project");

		makeEdge2SentenceChildParent(toProject, indicesforcp);
		// initialize param2scp: a flat view of edge2scp in (edge type, occurrence) order,
		// used as a sanity check that the parameter count stays consistent
		int numParams = 0;
		for (int edgeType = 0; edgeType < edge2scp.length; edgeType++) {
			numParams+= edge2scp[edgeType].length;
		}
		param2scp = new SentenceChildParent[numParams];
		int paramIndex = 0;
		for (int edgeType = 0; edgeType < edge2scp.length; edgeType++) {
			for (int index = 0; index < edge2scp[edgeType].length; index++) {
				param2scp[paramIndex++] = edge2scp[edgeType][index];
			}
		}

		// FIXME: edgesToNotProject has not been used for a while; do we want to keep it?
		if (fileOfAllowedTypes != null)
			edgesToNotProject = makeEdgesToNotProject(fileOfAllowedTypes);
		else {
			edgesToNotProject = new TIntArrayList();
		}
	}
	
	/**
	 * count the number of indices necessary for each child-parent type in the ragged array.  For Fernando 
	 * style constraints, this will be the number of 
	 * @param toProject
	 * @return
	 */
	public abstract ArrayList<Integer> countIndicesForChildParentType(ArrayList<WordInstance> toProject);

	/** Fills edge2scp using the counts in indicesforcp as per-type fill cursors. */
	public abstract void makeEdge2SentenceChildParent(ArrayList<WordInstance> toProject, ArrayList<Integer> indicesforcp);
	
	/**
	 * Reads a file of "parent child" pairs (one per line, '#' starts a comment) and
	 * returns the edge-type ids (both left and right direction) that should be
	 * excluded from projection.
	 *
	 * @param fname path of the file listing the parent/child pairs
	 * @return the distinct edge-type ids to skip
	 * @throws IOException if the file cannot be read
	 */
	private TIntArrayList makeEdgesToNotProject(String fname) throws IOException{
		TIntArrayList res = new TIntArrayList();
		BufferedReader in = new BufferedReader(new FileReader(fname));
		try {
			for (String ln = in.readLine(); ln!= null; ln=in.readLine()){
				ln = ln.replaceAll("#.*", "");  // strip comments
				ln = ln.replaceAll(" *$", "");  // strip trailing whitespace
				if (ln.length() == 0) continue;
				String[] fields = ln.split("  *");
				if (fields.length < 2) {
					// guard against malformed lines instead of crashing with ArrayIndexOutOfBounds
					System.out.println("Malformed line in "+fname+" (expected 'parent child'), skipping: "+ln);
					continue;
				}
				String par = fields[0];
				String child = fields[1];
				int edgeType = cstraints.getEdgeId(child, par, "left");
				if (edgeType < 0) System.out.println("Edge "+par+" -> "+child+" : left doesn't seem to exist, hope that's OK");
				if(!res.contains(edgeType)) res.add(edgeType);
				edgeType = cstraints.getEdgeId(child, par, "right");
				if (edgeType < 0) System.out.println("Edge "+par+" -> "+child+" : right doesn't seem to exist, hope that's OK");
				if(!res.contains(edgeType)) res.add(edgeType);
			}
		} finally {
			// close the reader even if getEdgeId or parsing throws (the original leaked it)
			in.close();
		}
		return res;
	}
	
	/**
	 * Returns the constraint strength to use for an edge type: 0 if the type is
	 * explicitly excluded or occurs fewer than minOccurrencesForProjection times,
	 * otherwise the configured constraintStrength.
	 */
	public double getConstraintStrength(int edgeType){
		double myCstrength = this.constraintStrength;
		if (edgesToNotProject.contains(edgeType)) return 0;
		// min occurrences for projection.. FIXME: this didn't help performance, and should be deleted
		if (minOccurrencesForProjection > edge2scp[edgeType].length){
			myCstrength = 0;
		}
		return myCstrength;
	}
	
    

	/**
	 * Projects the model posteriors onto the constraint set by optimizing the dual
	 * objective over lambda with projected gradient descent, then accumulates the
	 * resulting (projected) expected counts into counts.
	 *
	 * @param counts cleared, then filled with counts under the projected posteriors
	 * @param posteriors the per-sentence distributions (must be DepSentenceDist); their
	 *        caches are overwritten by the optimization
	 * @param trainStats notified that the E-step has started
	 * @param pr the posterior-regularization context passed through to trainStats
	 */
	@SuppressWarnings("unchecked")
	public void project(AbstractCountTable counts,
			AbstractSentenceDist[] posteriors, TrainStats trainStats, CorpusPR pr) {
		MemoryTracker mem  = new MemoryTracker();
		mem.start();
		trainStats.eStepStart(model, pr);
		int numParams = 0;
		for (int i = 0; i < edge2scp.length; i++) {
			numParams += edge2scp[i].length;
		}
		if (numParams!= param2scp.length) throw new AssertionError();
		// lazily allocate lambda and the posterior caches on the first projection
		if (lambda == null){
			lambda = new double[numParams];
			originalChildren = new double[posteriors.length][][][];
			originalRoots = new double[posteriors.length][];
		}
		// FIXME: figure out a way to check that sentences have not changed!
//		if (lambda.value.length != posteriors.length) throw new RuntimeException("num sentences changed!");
//		for (int i = 0; i < posteriors.length; i++) {
//			if (lambda.value[i].length != posteriors[i].depInst.numWords) throw new RuntimeException("sentence "+i+" length changed!");			
//		}
		for (int s = 0; s < posteriors.length; s++) {
			DepSentenceDist sd = (DepSentenceDist) posteriors[s];
			sd.cacheModelAndComputeIO(model.params);
		}
		// snapshot the unprojected posteriors; the objective mutates the originals
		for (int s = 0; s < posteriors.length; s++) {
			originalChildren[s] =  deepclone(((DepSentenceDist)posteriors[s]).child);
			originalRoots[s] = ((DepSentenceDist)posteriors[s]).root.clone();
		}
		ProjectedOptimizerStats stats = new ProjectedOptimizerStats();
		L1LMaxObjective objective = new L1LMaxObjective(lambda, this, posteriors);
		// objective.doTestGradient = true;
		//GenericPickFirstStep pickFirstStep = new GenericPickFirstStep(1000);
		
		LineSearchMethod linesearch = new ArmijoLineSearchMinimizationAlongProjectionArc(new NonNewtonInterpolationPickFirstStep(initialStep));
//		LineSearchMethod linesearch = new WolfRuleLineSearch(new NonNewtonInterpolationPickFirstStep(initialStep));
		//	LineSearchMethod linesearch = new WolfRuleLineSearch(pickFirstStep, c1, c2, 1000);
		ProjectedGradientDescent optimizer = new ProjectedGradientDescent(linesearch);
		optimizer.setMaxIterations(maxProjectionIterations);
//		GradientAscentProjection optimizer = new GradientAscentProjection(linesearch,stoppingPrecision, maxProjectionIterations);
//      StopingCriteria stopGrad = new NormalizedProjectedGradientL2Norm(stoppingPrecision);
//      StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
		// old code: stoppingPrecision*numParams and no stopValue, ever. 
        StopingCriteria stopGrad = new ProjectedGradientL2Norm(stoppingPrecision*numParams);
     // StopingCriteria stopValue = new NormalizedValueDifference(stoppingPrecision);
        CompositeStopingCriteria stop = new CompositeStopingCriteria();
        stop.add(stopGrad);
//        stop.add(stopValue);
        objective.setDebugLevel(3);
        boolean succeeded = optimizer.optimize(objective, stats, stop);
		// make sure we update the dual params
		objective.getValue();
		counts.clear();
		for (int i = 0; i < posteriors.length; i++) {
			model.addToCounts(posteriors[i], counts);
		}
		mem.finish();
		System.out.println("After  optimization:" + mem.print());
		// fixed: was "Suceess " + succed + "/n" (typo and a literal "/n" instead of a newline)
		System.out.println("Success " + succeeded + "\n" + stats.prettyPrint(1));
	}

	/** Overrides the maximum number of projected-gradient iterations per projection. */
	public void setMaxProjectionSteps(int tmpProjectItersAtPool) {
		maxProjectionIterations = tmpProjectItersAtPool;
	}

}
